<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="zh-Hans" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="zh-Hans" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  
  <link rel="shortcut icon" href="../../img/favicon.ico">
  <title>bAbI RNN - Keras 中文文档</title>
  <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>

  <link rel="stylesheet" href="../../css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
  
  <script>
    // Current page data, declared as global variables: the page title,
    // the path of the Markdown source file, and the URL of this page.
    var mkdocs_page_name = "Baby RNN";
    var mkdocs_page_input_path = "examples/babi_rnn.md";
    var mkdocs_page_url = "/zh/examples/babi_rnn/";
  </script>
  
  <script src="../../js/jquery-2.1.1.min.js" defer></script>
  <script src="../../js/modernizr-2.8.3.min.js" defer></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
  <script>hljs.initHighlightingOnLoad();</script> 
  
  <script>
      // Google Analytics (analytics.js) loader. It stores the name 'ga' in
      // window.GoogleAnalyticsObject, defines window.ga (unless it already
      // exists) as a stub that pushes its arguments onto ga.q, records the
      // current time in ga.l, and inserts an async script element pointing at
      // analytics.js before the first script element of the document.
      (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
      (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
      m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
      })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

      // Queue the tracker creation and a pageview hit through ga().
      ga('create', 'UA-61785484-1', 'keras.io');
      ga('send', 'pageview');
  </script>
  
</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
      <div class="wy-side-nav-search">
        <a href="../.." class="icon icon-home"> Keras 中文文档</a>
        <div role="search">
  <form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" title="Type search term here" />
  </form>
</div>
      </div>

      <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
	<ul class="current">
	  
          
            <li class="toctree-l1">
		
    <a class="" href="../..">主页</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../why-use-keras/">为什么选择 Keras?</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">快速开始</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../getting-started/sequential-model-guide/">Sequential 顺序模型指引</a>
                </li>
                <li class="">
                    
    <a class="" href="../../getting-started/functional-api-guide/">函数式 API 指引</a>
                </li>
                <li class="">
                    
    <a class="" href="../../getting-started/faq/">FAQ 常见问题解答</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">模型</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../models/about-keras-models/">关于 Keras 模型</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/sequential/">Sequential 顺序模型 API</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/model/">函数式 API</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">Layers</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../layers/about-keras-layers/">关于 Keras 网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/core/">核心网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/convolutional/">卷积层 Convolutional</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/pooling/">池化层 Pooling</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/local/">局部连接层 Locally-connected</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/recurrent/">循环层 Recurrent</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/embeddings/">嵌入层 Embedding</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/merge/">融合层 Merge</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/advanced-activations/">高级激活层 Advanced Activations</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/normalization/">标准化层 Normalization</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/noise/">噪声层 Noise</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/wrappers/">层封装器 wrappers</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/writing-your-own-keras-layers/">编写你自己的层</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">数据预处理</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../preprocessing/sequence/">序列预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/text/">文本预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/image/">图像预处理</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../losses/">损失函数 Losses</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../metrics/">评估标准 Metrics</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../optimizers/">优化器 Optimizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../activations/">激活函数 Activations</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../callbacks/">回调函数 Callbacks</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../datasets/">常用数据集 Datasets</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../applications/">应用 Applications</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../backend/">后端 Backend</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../initializers/">初始化 Initializers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../regularizers/">正则化 Regularizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../constraints/">约束 Constraints</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../visualization/">可视化 Visualization</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../scikit-learn-api/">Scikit-learn API</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../utils/">工具</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../contributing/">贡献</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">经典样例</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../addition_rnn/">Addition RNN</a>
                </li>
                <li class=" current">
                    
    <a class="current" href="./">bAbI RNN</a>
    <ul class="subnav">
            
    <li class="toctree-l3"><a href="#_1">基于故事和问题训练两个循环神经网络。</a></li>
    
        <ul>
        
            <li><a class="toctree-l4" href="#_2">注意</a></li>
        
        </ul>
    

    </ul>
                </li>
                <li class="">
                    
    <a class="" href="../babi_memnn/">Baby MemNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn/">CIFAR-10 CNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn_capsule/">CIFAR-10 CNN-Capsule</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn_tfaugment2d/">CIFAR-10 CNN with augmentation (TF)</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_resnet/">CIFAR-10 ResNet</a>
                </li>
                <li class="">
                    
    <a class="" href="../conv_filter_visualization/">Convolution filter visualization</a>
                </li>
                <li class="">
                    
    <a class="" href="../image_ocr/">Image OCR</a>
                </li>
                <li class="">
                    
    <a class="" href="../imdb_bidirectional_lstm/">Bidirectional LSTM</a>
                </li>
    </ul>
	    </li>
          
        </ul>
      </div>
      &nbsp;
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../..">Keras 中文文档</a>
      </nav>

      
      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <li><a href="../..">Docs</a> &raquo;</li>
    
      
        
          <li>经典样例 &raquo;</li>
        
      
    
    <li>bAbI RNN</li>
    <li class="wy-breadcrumbs-aside">
      
        <a href="https://github.com/keras-team/keras-docs-zh/edit/master/docs/examples/babi_rnn.md"
          class="icon icon-github"> Edit on GitHub</a>
      
    </li>
  </ul>
  <hr/>
</div>
          <div role="main">
            <div class="section">
              
                <h1 id="_1">基于故事和问题训练两个循环神经网络。</h1>
<p>两者的合并向量将用于回答一系列 bAbI 任务。</p>
<p>这些结果与 Weston 等人提供的 LSTM 模型的结果相当：<a href="http://arxiv.org/abs/1502.05698">Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks</a>。</p>
<table>
<thead>
<tr>
<th>Task Number</th>
<th>FB LSTM Baseline</th>
<th>Keras QA</th>
</tr>
</thead>
<tbody>
<tr>
<td>QA1 - Single Supporting Fact</td>
<td>50</td>
<td>52.1</td>
</tr>
<tr>
<td>QA2 - Two Supporting Facts</td>
<td>20</td>
<td>37.0</td>
</tr>
<tr>
<td>QA3 - Three Supporting Facts</td>
<td>20</td>
<td>20.5</td>
</tr>
<tr>
<td>QA4 - Two Arg. Relations</td>
<td>61</td>
<td>62.9</td>
</tr>
<tr>
<td>QA5 - Three Arg. Relations</td>
<td>70</td>
<td>61.9</td>
</tr>
<tr>
<td>QA6 - Yes/No Questions</td>
<td>48</td>
<td>50.7</td>
</tr>
<tr>
<td>QA7 - Counting</td>
<td>49</td>
<td>78.9</td>
</tr>
<tr>
<td>QA8 - Lists/Sets</td>
<td>45</td>
<td>77.2</td>
</tr>
<tr>
<td>QA9 - Simple Negation</td>
<td>64</td>
<td>64.0</td>
</tr>
<tr>
<td>QA10 - Indefinite Knowledge</td>
<td>44</td>
<td>47.7</td>
</tr>
<tr>
<td>QA11 - Basic Coreference</td>
<td>72</td>
<td>74.9</td>
</tr>
<tr>
<td>QA12 - Conjunction</td>
<td>74</td>
<td>76.4</td>
</tr>
<tr>
<td>QA13 - Compound Coreference</td>
<td>94</td>
<td>94.4</td>
</tr>
<tr>
<td>QA14 - Time Reasoning</td>
<td>27</td>
<td>34.8</td>
</tr>
<tr>
<td>QA15 - Basic Deduction</td>
<td>21</td>
<td>32.4</td>
</tr>
<tr>
<td>QA16 - Basic Induction</td>
<td>23</td>
<td>50.6</td>
</tr>
<tr>
<td>QA17 - Positional Reasoning</td>
<td>51</td>
<td>49.1</td>
</tr>
<tr>
<td>QA18 - Size Reasoning</td>
<td>52</td>
<td>90.8</td>
</tr>
<tr>
<td>QA19 - Path Finding</td>
<td>8</td>
<td>9.0</td>
</tr>
<tr>
<td>QA20 - Agent's Motivations</td>
<td>91</td>
<td>90.7</td>
</tr>
</tbody>
</table>
<p>有关 bAbI 项目的相关资源，请参考：<a href="https://research.facebook.com/researchers/1543934539189348">https://research.facebook.com/researchers/1543934539189348</a></p>
<h3 id="_2">注意</h3>
<ul>
<li>使用默认的单词、句子和查询向量尺寸，GRU 模型得到了以下效果：</li>
<li>20 轮迭代后，在 QA1 上达到了 52.1% 的测试准确率（在 CPU 上每轮迭代 2 秒）；</li>
<li>
<p>20 轮迭代后，在 QA2 上达到了 37.0% 的测试准确率（在 CPU 上每轮迭代 16 秒）。</p>
<p>相比之下，Facebook的论文中 LSTM baseline 的准确率分别是 50% 和 20%。</p>
</li>
<li>
<p>这类任务传统上并不单独解析问题。单独解析问题很可能会提高准确率，这也是合并两个 RNN 的一个很好的示例。</p>
</li>
<li>
<p>故事和问题的 RNN 之间不共享词向量（词嵌入）。</p>
</li>
<li>
<p>注意观察训练样本从 1000 个增加到 10,000 个（en-10k）时准确率如何变化。使用 1000 个样本是为了与原始论文进行对比。</p>
</li>
<li>
<p>尝试使用 GRU, LSTM 和 JZS1-3，因为它们会产生微妙的不同结果。</p>
</li>
<li>
<p>长度和噪声（即「无用」的故事内容）会影响 LSTM/GRU 提供正确答案的能力。在只提供支持性事实的情况下，这些 RNN 可以在许多任务上达到 100% 的准确率。使用注意力机制的记忆网络和神经网络可以有效地在这些噪声中搜索相关的语句，从而大大提高性能。这在 QA2 和 QA3 上尤为明显，两者的故事都比 QA1 长得多。</p>
</li>
</ul>
<pre><code class="python">from __future__ import print_function
from functools import reduce
import re
import tarfile

import numpy as np

from keras.utils.data_utils import get_file
from keras.layers.embeddings import Embedding
from keras import layers
from keras.layers import recurrent
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences


def tokenize(sent):
    '''Return the tokens of a sentence, including punctuation.

    The sentence is split on runs of non-word characters; the separators are
    kept, every piece is stripped of surrounding whitespace, and pieces that
    end up empty are dropped.

    &gt;&gt;&gt; tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # The separator group must not be optional. A pattern that can match the
    # empty string makes re.split() on Python 3.7+ split between every
    # character and yield None for the unmatched group, which then fails on
    # None.strip().
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.

    Each element of lines is decoded as UTF-8 and must start with a numeric
    id followed by a space. A line containing tabs is a question line of the
    form "question TAB answer TAB supporting ids"; any other line is a story
    sentence.

    If only_supporting is true, only the sentences that support the answer
    are kept.

    Returns a list of (substory, question, answer) tuples, where substory is
    a list of tokenized sentences and question is a list of tokens.
    '''
    parsed = []
    sentences = []
    for raw in lines:
        index, text = raw.decode('utf-8').strip().split(' ', 1)
        if int(index) == 1:
            # Id 1 marks the start of a new story.
            sentences = []
        if '\t' not in text:
            sentences.append(tokenize(text))
            continue
        question, answer, supporting = text.split('\t')
        question = tokenize(question)
        if only_supporting:
            # Only select the referenced sentences (ids are 1-based).
            substory = [sentences[int(ref) - 1] for ref in supporting.split()]
        else:
            # Provide every sentence seen so far, skipping the placeholders.
            substory = [sentence for sentence in sentences if sentence]
        parsed.append((substory, question, answer))
        # Placeholder for the question line so that list positions keep
        # matching the line ids.
        sentences.append('')
    return parsed


def get_stories(f, only_supporting=False, max_length=None):
    '''给定文件名，读取文件，检索故事，
    然后将句子转换为一个独立故事。

    如果提供了 max_length,
    任何长于 max_length 的故事都将被丢弃。
    '''
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    flatten = lambda data: reduce(lambda x, y: x + y, data)
    data = [(flatten(story), q, answer) for story, q, answer in data
            if not max_length or len(flatten(story)) &lt; max_length]
    return data


def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
    '''Turn (story, query, answer) triples into index sequences and targets.

    Story and query tokens are mapped through word_idx and then passed to
    pad_sequences with maxlen set to story_maxlen and query_maxlen. Each
    answer becomes a vector of length len(word_idx) + 1 that is 1 at the
    answer's index and 0 elsewhere.

    Returns a tuple (stories, queries, answers).
    '''
    story_ids, query_ids, answers = [], [], []
    for story, query, answer in data:
        story_ids.append([word_idx[token] for token in story])
        query_ids.append([word_idx[token] for token in query])
        # Index 0 is reserved, hence one slot more than there are words.
        target = np.zeros(len(word_idx) + 1)
        target[word_idx[answer]] = 1
        answers.append(target)
    padded_stories = pad_sequences(story_ids, maxlen=story_maxlen)
    padded_queries = pad_sequences(query_ids, maxlen=query_maxlen)
    return padded_stories, padded_queries, np.array(answers)

# Recurrent layer class used to encode both the story and the question.
RNN = recurrent.LSTM
# Embedding size passed to both Embedding layers.
EMBED_HIDDEN_SIZE = 50
# Size arguments passed to the story RNN and to the question RNN.
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100
# Batch size (training and evaluation) and number of training epochs.
BATCH_SIZE = 32
EPOCHS = 20
print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN,
                                                           EMBED_HIDDEN_SIZE,
                                                           SENT_HIDDEN_SIZE,
                                                           QUERY_HIDDEN_SIZE))

# Fetch the bAbI tasks archive with get_file. On any failure, print manual
# download instructions and re-raise the original exception.
try:
    path = get_file('babi-tasks-v1-2.tar.gz',
                    origin='https://s3.amazonaws.com/text-datasets/'
                           'babi_tasks_1-20_v1-2.tar.gz')
except:
    print('Error downloading dataset, please download it manually:\n'
          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2'
          '.tar.gz\n'
          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
    raise

# Select the task file by leaving exactly one assignment uncommented.
# The '{}' placeholder is filled with 'train' or 'test' below.
# Default QA1 with 1000 samples
# challenge = 'tasks_1-20_v1-2/en/qa1_single-supporting-fact_{}.txt'
# QA1 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt'
# QA2 with 1000 samples
challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
# QA2 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt'
# Read the train and test members of the archive.
with tarfile.open(path) as tar:
    train = get_stories(tar.extractfile(challenge.format('train')))
    test = get_stories(tar.extractfile(challenge.format('test')))

# Collect every story token, question token and answer from both splits
# into a sorted vocabulary.
vocab = set()
for story, q, answer in train + test:
    vocab |= set(story + q + [answer])
vocab = sorted(vocab)

# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
# Word indices therefore start at 1.
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
# Longest story and longest question over train and test; used as the
# padding lengths below.
story_maxlen = max(map(len, (x for x, _, _ in train + test)))
query_maxlen = max(map(len, (x for _, x, _ in train + test)))

# Vectorize the training split (x, xq, y) and the test split (tx, txq, ty).
x, xq, y = vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
tx, txq, ty = vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

print('vocab = {}'.format(vocab))
print('x.shape = {}'.format(x.shape))
print('xq.shape = {}'.format(xq.shape))
print('y.shape = {}'.format(y.shape))
print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))

print('Build model...')

# Story branch: integer token ids, then an Embedding layer, then the RNN.
sentence = layers.Input(shape=(story_maxlen,), dtype='int32')
encoded_sentence = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(sentence)
encoded_sentence = RNN(SENT_HIDDEN_SIZE)(encoded_sentence)

# Question branch: same structure, with its own Embedding and RNN layers
# (nothing is shared with the story branch).
question = layers.Input(shape=(query_maxlen,), dtype='int32')
encoded_question = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(question)
encoded_question = RNN(QUERY_HIDDEN_SIZE)(encoded_question)

# Concatenate both encodings and apply a softmax layer with vocab_size outputs.
merged = layers.concatenate([encoded_sentence, encoded_question])
preds = layers.Dense(vocab_size, activation='softmax')(merged)

model = Model([sentence, question], preds)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

print('Training')
# validation_split=0.05 sets aside 5% of the training samples for validation.
model.fit([x, xq], y,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS,
          validation_split=0.05)

print('Evaluation')
# Report loss and accuracy on the test split.
loss, acc = model.evaluate([tx, txq], ty,
                           batch_size=BATCH_SIZE)
print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc))
</code></pre>
              
            </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="../babi_memnn/" class="btn btn-neutral float-right" title="Baby MemNN">Next <span class="icon icon-circle-arrow-right"></span></a>
      
      
        <a href="../addition_rnn/" class="btn btn-neutral" title="Addition RNN"><span class="icon icon-circle-arrow-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <!-- Copyright etc -->
    
  </div>

  Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
      
        </div>
      </div>

    </section>

  </div>

  <div class="rst-versions" role="note" style="cursor: pointer">
    <span class="rst-current-version" data-toggle="rst-current-version">
      
          <a href="https://github.com/keras-team/keras-docs-zh/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
      
      
        <span><a href="../addition_rnn/" style="color: #fcfcfc;">&laquo; Previous</a></span>
      
      
        <span style="margin-left: 15px"><a href="../babi_memnn/" style="color: #fcfcfc">Next &raquo;</a></span>
      
    </span>
</div>
    <script>var base_url = '../..';</script>
    <script src="../../js/theme.js" defer></script>
      <script src="../../search/main.js" defer></script>

</body>
</html>
